import os, hashlib
from collections.abc import Mapping

import requests
from tqdm import tqdm

import torch
import torch.distributed as dist

# Known pretrained checkpoints, keyed by model name (see get_ckpt_path):
#   URL_MAP  - download location of the checkpoint,
#   CKPT_MAP - file name it is stored under, relative to the caller's root dir,
#   MD5_MAP  - expected MD5 hex digest used to validate the downloaded file.
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

CKPT_MAP = {"vgg_lpips": "vgg.pth"}

MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}


def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a tqdm progress bar.

    The payload is first written to ``<local_path>.tmp`` and only moved to
    ``local_path`` once the transfer has finished, so a failed or interrupted
    download never leaves a truncated file at ``local_path``.

    Args:
        url: HTTP(S) URL to fetch.
        local_path: destination file path; missing parent directories are
            created.
        chunk_size: number of bytes requested per ``iter_content`` chunk.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    parent_dir = os.path.dirname(local_path)
    if parent_dir:  # a bare file name has no directory to create
        os.makedirs(parent_dir, exist_ok=True)
    tmp_path = os.fspath(local_path) + ".tmp"

    with requests.get(url, stream=True) as r:
        # Otherwise an HTML error page would be saved as the checkpoint.
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        try:
            with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
                with open(tmp_path, "wb") as f:
                    for data in r.iter_content(chunk_size=chunk_size):
                        if data:
                            f.write(data)
                            # The last chunk is usually shorter than chunk_size.
                            pbar.update(len(data))
        except BaseException:
            # Drop the partial file so it is never mistaken for a checkpoint.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    os.replace(tmp_path, local_path)


def md5_hash(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at ``path``.

    The file is hashed in ``chunk_size``-byte pieces (1 MiB by default), so
    large checkpoints are never loaded into memory all at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name``, downloading it if needed.

    Args:
        name: key into ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
        root: directory that holds (or will hold) the checkpoint file.
        check: if True, also re-download when an existing file's MD5 digest
            does not match ``MD5_MAP[name]``.

    Returns:
        ``os.path.join(root, CKPT_MAP[name])``.

    Raises:
        ValueError: if ``name`` is not a known checkpoint.
        RuntimeError: if the freshly downloaded file fails the MD5 check. The
            bad file is deleted so a later call retries the download instead
            of returning it.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if name not in URL_MAP:
        raise ValueError(
            "Unknown checkpoint {!r}; expected one of {}".format(name, sorted(URL_MAP))
        )
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != MD5_MAP[name]:
            os.remove(path)
            raise RuntimeError(
                "MD5 mismatch for {}: expected {}, got {}".format(
                    path, MD5_MAP[name], md5
                )
            )
    return path


class KeyNotFoundError(Exception):
    """Raised by :func:`retrieve` when a path component cannot be resolved.

    Attributes:
        cause: the underlying exception that stopped the lookup.
        keys: the full list of path components being looked up, if given.
        visited: the components successfully traversed before the failure,
            if given.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
        list_or_dict : list or dict
            Possibly nested list or dictionary. Any ``collections.abc.Mapping``
            (e.g. an OmegaConf ``DictConfig``) is indexed by the string key;
            any other container is indexed by ``int(key)``. May itself be a
            callable when ``expand`` is True.
        key : str
            key/to/value, path like string describing all keys necessary to
            consider to get to the desired value. List indices can also be
            passed here.
        splitval : str
            String that defines the delimiter between keys of the
            different depth levels in `key`.
        default : obj
            Value returned if :attr:`key` is not found. ``None`` means "no
            default": the lookup error is raised instead.
        expand : bool
            Whether to expand callable nodes on the path or not. Expanded
            values are written back into their parent container.
        pass_success : bool
            If True, return ``(value, success)``, where ``success`` is False
            exactly when ``default`` was returned.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``.

    Raises
    ------
        KeyNotFoundError
            If a path component cannot be resolved (missing key, bad or
            out-of-range index, indexing into a non-container leaf, or a
            callable on the path with ``expand=False``) and :attr:`default`
            is ``None``.
    """
    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for part in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                # Store the expansion in place; a callable root has no parent.
                if parent is not None:
                    parent[last_key] = list_or_dict

            parent = list_or_dict
            try:
                # Keep the key actually used for indexing: sequences need the
                # int so that the write-back above/below works on lists too.
                if isinstance(list_or_dict, Mapping):
                    last_key = part
                else:
                    last_key = int(part)
                list_or_dict = list_or_dict[last_key]
            except (KeyError, IndexError, ValueError, TypeError) as e:
                # TypeError: the current node is a leaf that cannot be indexed.
                raise KeyNotFoundError(e, keys=keys, visited=visited) from e

            visited.append(part)

        # Final expansion of the retrieved value.
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError:
        if default is None:
            raise
        list_or_dict = default
        success = False

    if pass_success:
        return list_or_dict, success
    return list_or_dict


def pixel_unshuffle(input, out_size):
    """
    Move each non-overlapping spatial patch of ``input`` into the channel axis.

    The patch size is ``kh = H // out_size[0]`` by ``kw = W // out_size[1]``
    (the two factors may differ). Output channel ``c * kh * kw + i * kw + j``
    at position ``(y, x)`` holds ``input[:, c, y * kh + i, x * kw + j]``, which
    is the same channel order as ``torch.nn.functional.pixel_unshuffle`` for
    square factors. Only reshape/permute are used, so any dtype is accepted.

    Args:
        input: 4D tensor, B x C x H x W.
        out_size: (H/kh, W/kw) target spatial size; must evenly divide (H, W).

    Returns:
        Tensor of shape B x (C*kh*kw) x H/kh x W/kw. When ``out_size`` equals
        the input size the result may share storage with ``input``.

    Raises:
        ValueError: if ``out_size`` does not evenly divide the input size.
    """
    b, c, h1, w1 = input.shape
    h2, w2 = int(out_size[0]), int(out_size[1])
    if h2 <= 0 or w2 <= 0 or h1 % h2 != 0 or w1 % w2 != 0:
        raise ValueError(
            "pixel_unshuffle only supports divisible resizing, got input size "
            "{} and out_size {}".format((h1, w1), (h2, w2))
        )
    kh, kw = h1 // h2, w1 // w2

    # B x C x (h2*kh) x (w2*kw) -> B x C x kh x kw x h2 x w2
    patches = input.reshape(b, c, h2, kh, w2, kw).permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(b, c * kh * kw, h2, w2)



def pixel_shuffle(input, out_size):
    """
    Move groups of channels of ``input`` back into spatial patches.

    This is the inverse of ``pixel_unshuffle``. With ``sh = out_size[0] // H``
    and ``sw = out_size[1] // W``,
    ``output[:, c, y * sh + i, x * sw + j] == input[:, c * sh * sw + i * sw + j, y, x]``.
    Equal factors are delegated to ``torch.nn.functional.pixel_shuffle``.
    Unequal factors use an equivalent reshape/permute.

    Args:
        input: 4D tensor, B x C x H x W, with C divisible by ``sh * sw``.
        out_size: (H*sh, W*sw) target spatial size.

    Returns:
        Tensor of shape B x C/(sh*sw) x H*sh x W*sw.

    Raises:
        ValueError: if ``out_size`` is not a positive multiple of the input
            size, or C is not divisible by ``sh * sw``.
    """
    b, c1, h1, w1 = input.shape
    h2, w2 = int(out_size[0]), int(out_size[1])
    if min(h1, w1, h2, w2) <= 0 or h2 % h1 != 0 or w2 % w1 != 0:
        raise ValueError(
            "pixel_shuffle only supports divisible resizing, got input size "
            "{} and out_size {}".format((h1, w1), (h2, w2))
        )
    sh, sw = h2 // h1, w2 // w1
    if c1 % (sh * sw) != 0:
        raise ValueError(
            "pixel_shuffle needs the channel count ({}) to be divisible by "
            "{}".format(c1, sh * sw)
        )

    if sh == sw:
        return torch.nn.functional.pixel_shuffle(input, sh)

    c2 = c1 // (sh * sw)
    # B x (c2*sh*sw) x H x W -> B x c2 x H x sh x W x sw
    patches = input.reshape(b, c2, sh, sw, h1, w1).permute(0, 1, 4, 2, 5, 3)
    return patches.reshape(b, c2, h2, w2)


def get_token_type(mask, token_shape):
    """
    Classify every token by the mask values inside its image patch.

    The mask is cut into one non-overlapping patch per token with
    ``pixel_unshuffle``, and the mask values of each patch are summed. The
    result is a long tensor of shape B x 1 x H/r x W/r holding:
        0: the patch sum is 0 (every pixel of a 0/1 mask is 0),
        1: the patch sum equals the number of values in the patch
           (every pixel of a 0/1 mask is 1),
        2: anything else (a mix of 0s and 1s).
    Under the convention below (1 = masked pixel), 0 means the token is
    entirely unmasked and 1 means entirely masked.
    NOTE(review): if callers use 1 to mark *valid* pixels instead, the
    meanings of types 0 and 1 swap; confirm the convention at the call sites.

    Args:
        mask: 4D tensor, B x 1 x H x W, the mask of the original image,
            expected to hold 0/1 values (1 denotes masked pixels and 0
            denotes unmasked pixels).
        token_shape: [H/r, W/r], the spatial shape of the token grid.
    """
    # One channel per pixel of the token's patch: B x (C*r^2) x H/r x W/r.
    patches = pixel_unshuffle(mask.float(), token_shape)
    values_per_token = patches.shape[1]
    patch_sum = patches.sum(dim=1, keepdim=True)  # B x 1 x H/r x W/r

    token_type = torch.full_like(patch_sum, 2, dtype=torch.long)
    token_type[patch_sum == 0] = 0  # no pixel of the patch is set
    token_type[patch_sum == values_per_token] = 1  # every pixel is set
    return token_type

def is_dist_avail_and_initialized():
    """Report whether torch.distributed is compiled in and a process group is set up."""
    # Short-circuit: is_initialized() is only consulted when the backend exists.
    return dist.is_available() and dist.is_initialized()

def get_world_size():
    """Return the default process group's world size, or 1 when not distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1

def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
    """All-reduce ``tensor`` in place across the default process group.

    Nothing is done when the world size is 1 (including non-distributed runs).

    Args:
        tensor: tensor to reduce; it is modified in place.
        op: reduction operation passed to ``dist.all_reduce``.
        async_op: forwarded to ``dist.all_reduce``. This helper does not return
            the resulting work handle, so it waits on it before returning.
            The returned tensor therefore always holds the reduced values.

    Returns:
        ``tensor`` itself.
    """
    if get_world_size() == 1:
        return tensor
    work = dist.all_reduce(tensor, op=op, async_op=async_op)
    if async_op:
        # Without this the caller could read `tensor` before the collective
        # has finished, and has no handle to wait on.
        work.wait()
    return tensor

if __name__ == "__main__":
    # Demo: look up a top-level and a nested key in an OmegaConf config.
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "keya": "a",
            "keyb": "b",
            "keyc": {"cc1": 1, "cc2": 2},
        }
    )
    print(config)
    print(retrieve(config, "keya"))
    print(retrieve(config, "keyc/cc1"))

